import argparse
import json
import os
import re
import time

import paddle
from paddleformers.trainer import strtobool
from paddleformers.transformers.configuration_utils import PretrainedConfig
from paddleformers.transformers.model_utils import shard_checkpoint
from paddleformers.utils.env import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
from paddleformers.utils.log import logger
from safetensors.numpy import save_file as safe_save_file

from fastdeploy.input.ernie_tokenizer import ErnieBotTokenizer
from fastdeploy.model_executor.layers.utils import get_tensor
from fastdeploy.model_executor.load_weight_utils import (
    get_all_safetensors,
    safetensors_weights_iterator,
)


def parse_arguments():
    """
    Parse the command-line arguments of the W4A8 conversion script.

    Returns:
        argparse.Namespace with:
          * ``model_name_or_path`` -- directory of the input model (required).
          * ``output_dir`` -- directory for the converted model; defaults to
            ``"merged_output"`` when omitted.
          * ``safe_serialization`` -- parsed with ``strtobool``; defaults to
            ``"True"`` (argparse runs string defaults through ``type``).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        required=True,
        help="The directory of model.",
    )

    # Not required: previously ``required=True`` made the default unreachable.
    parser.add_argument(
        "--output_dir",
        default="merged_output",
        help="The directory of merged model output.",
    )

    parser.add_argument(
        "--safe_serialization",
        type=strtobool,
        default="True",
        help="Whether merge the model into safetensors format.",
    )

    return parser.parse_args()


def reorder(arch=80):
    """
    Build a converter for W4A8 ``quant_weight`` tensors.

    The returned callable moves the weight to the GPU and runs
    ``paddle.nn.quant.weight_quantize(weight, algo="w4a8", arch=arch)``. It keeps
    only the first output (the second output is discarded) and returns it on the CPU.

    Args:
        arch: Target GPU compute capability forwarded to ``weight_quantize``
            (e.g. 80 for SM80 / A100). Defaults to 80, the previously
            hard-coded value.

    Returns:
        A callable ``fn(weight)`` returning the processed CPU tensor.
    """

    def fn(weight):
        from paddle.nn.quant import weight_quantize

        quant_weight, _ = weight_quantize(weight.cuda(), algo="w4a8", arch=arch)
        return quant_weight.cpu()

    return fn


def deal_in_scale():
    """
    Build the converter applied to ``activation_scale`` tensors.

    Returns:
        A callable that maps ``in_scale`` to its reciprocal, ``1 / in_scale``.
    """

    def _invert(in_scale):
        return 1 / in_scale

    return _invert


def deal_weight_scale():
    """
    Build the converter applied to ``weight_scale`` tensors.

    Returns:
        A callable ``(weight_scale, processed_in_scale)`` returning
        ``weight_scale / (127 * 112) / processed_in_scale``, where
        ``processed_in_scale`` is the already-converted activation scale.
        NOTE(review): 127 looks like the int8 maximum and 112 a 4-bit weight
        range factor -- confirm against the quantization scheme.
    """
    divisor = 127 * 112

    def _rescale(weight_scale, processed_in_scale):
        # Two successive divisions, in the same order as the original formula.
        return weight_scale / divisor / processed_in_scale

    return _rescale


# Temporary: only W4A8-quantized MoE expert tensors are supported for now.
def deal_quant(state_dict, save_state_dict):
    """
    Move W4A8 MoE expert tensors from ``state_dict`` to ``save_state_dict``, converted.

    Each key matching one of the expert patterns is popped from ``state_dict``.
    Its converted value is stored under the same key in ``save_state_dict``:

    * ``...activation_scale`` -> ``deal_in_scale()`` (reciprocal).
    * ``...weight_scale`` -> ``deal_weight_scale()``, combined with the
      already-converted ``activation_scale`` of the same expert projection.
    * ``...quant_weight`` -> ``reorder()``.

    Keys matching no pattern are left in ``state_dict`` untouched.

    Args:
        state_dict: Source mapping of tensor name to tensor; mutated in place.
        save_state_dict: Destination mapping; mutated in place.

    Raises:
        KeyError: If a ``weight_scale`` key has no converted ``activation_scale``
            counterpart in ``save_state_dict``.
    """
    # Order matters: activation scales are converted first because every weight
    # scale is combined with its processed activation scale.
    w4a8_mapping = [
        # (compiled pattern, converter, needs the processed activation scale)
        (
            re.compile(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.activation_scale"),
            deal_in_scale(),
            False,
        ),
        (
            re.compile(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.weight_scale"),
            deal_weight_scale(),
            True,
        ),
        (
            re.compile(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.([^.]+)\.quant_weight"),
            reorder(),
            False,
        ),
    ]
    for pattern, fn, needs_in_scale in w4a8_mapping:
        # Snapshot the keys because matching entries are popped during the scan.
        for key in list(state_dict):
            if not pattern.search(key):
                continue
            if needs_in_scale:
                in_scale_key = key.replace("weight_scale", "activation_scale")
                # Validate before popping so state is untouched on failure.
                if in_scale_key not in save_state_dict:
                    raise KeyError(f"missing activation scale '{in_scale_key}' required by '{key}'")
                save_state_dict[key] = fn(state_dict.pop(key), save_state_dict[in_scale_key])
            else:
                save_state_dict[key] = fn(state_dict.pop(key))


def save_safetensors(state_dict, args):
    """
    Save ``state_dict`` into ``args.output_dir`` as sharded safetensors files.

    The function proceeds in four steps:

    * Converts ``paddle.Tensor`` values to numpy arrays in place, which mutates
      the caller's dict.
    * Splits the dict into shards of at most 5GB with ``shard_checkpoint``.
    * Writes every shard with ``safetensors.numpy``.
    * Writes the shard index as JSON to ``SAFE_WEIGHTS_INDEX_NAME`` when one is
      returned.

    Args:
        state_dict: Mapping of tensor name to ``paddle.Tensor`` or numpy array.
        args: Parsed CLI arguments; only ``args.output_dir`` is read.
    """
    logger.info("Move to numpy.")
    for k in list(state_dict):
        value = state_dict[k]
        if isinstance(value, paddle.Tensor):
            # Overwrite in place (rather than pop + re-insert) so the key order,
            # and therefore shard grouping, follows the caller's order.
            state_dict[k] = value.cpu().numpy()

    logger.info("Save safetensors files.")
    shards, index = shard_checkpoint(
        state_dict,
        max_shard_size="5GB",
        weights_name=SAFE_WEIGHTS_NAME,
        shard_format="naive",
    )
    for shard_file, shard in shards.items():
        save_file = os.path.join(args.output_dir, shard_file)
        logger.info(f"Saving {save_file}")
        safe_save_file(shard, save_file, metadata={"format": "np"})

    # NOTE(review): shard_checkpoint (HF/PaddleNLP-style) returns index=None when
    # everything fits in one shard. json.dumps(None) would write an index file
    # containing just "null", so no index is written in that case.
    if index is None:
        return

    save_index_file = os.path.join(args.output_dir, SAFE_WEIGHTS_INDEX_NAME)
    with open(save_index_file, "w", encoding="utf-8") as f:
        content = json.dumps(index, indent=2) + "\n"
        f.write(content)


def main():
    """
    Load a checkpoint, convert its W4A8 expert tensors and save the result.

    Steps:
      1. Read the config from ``--model_name_or_path``. Load the tokenizer using
         the first vocab file found among the known file names.
      2. Load every safetensors weight to CPU, then convert the expert
         quantization tensors via ``deal_quant``. All other tensors are copied
         through unchanged.
      3. Save the weights to ``--output_dir`` (safetensors, or
         ``model_state.pdparams`` when ``--safe_serialization`` is false).
         Then save the config with ``is_permuted = True`` and the tokenizer.
    """
    args = parse_arguments()
    config_dict, _ = PretrainedConfig.get_config_dict(args.model_name_or_path)
    pretrained_config = PretrainedConfig.from_dict(config_dict)

    vocab_file_names = [
        "tokenizer.model",
        "spm.model",
        "ernie_token_100k.model",
    ]
    for vocab_file_name in vocab_file_names:
        if os.path.exists(os.path.join(args.model_name_or_path, vocab_file_name)):
            # Mutates the class-level resource table, presumably so that
            # from_pretrained below loads this vocab file.
            ErnieBotTokenizer.resource_files_names["vocab_file"] = vocab_file_name
            break
    tokenizer = ErnieBotTokenizer.from_pretrained(args.model_name_or_path)

    _, safetensor_files = get_all_safetensors(args.model_name_or_path)
    weights_iterator = safetensors_weights_iterator(safetensor_files)
    state_dict = {}
    save_state_dict = {}
    start = time.perf_counter()
    for k, v in weights_iterator:
        state_dict[k] = get_tensor(v).cpu()
    # Quantize inside the timed region so the logged duration really covers
    # "load and quantize" (previously only loading was timed).
    deal_quant(state_dict, save_state_dict)
    end = time.perf_counter()
    logger.info("Finish Quantize.")
    logger.info(f"load and quantize took : {end - start:.6f} seconds")

    # Carry over every tensor deal_quant did not consume.
    for key in list(state_dict):
        save_state_dict[key] = state_dict.pop(key)

    logger.info("Begin to save model")
    os.makedirs(args.output_dir, exist_ok=True)
    start = time.perf_counter()
    if not args.safe_serialization:
        paddle.save(
            save_state_dict,
            os.path.join(args.output_dir, "model_state.pdparams"),
        )
    else:
        save_safetensors(save_state_dict, args)
    pretrained_config.is_permuted = True
    pretrained_config.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    end = time.perf_counter()
    logger.info(f"save model took: {end - start:.6f} seconds")
    logger.info("Finish.")


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
